{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e6ac660a-892d-4aae-a936-e65e2b999c28",
   "metadata": {
    "papermill": {
     "duration": 0.001709,
     "end_time": "2023-07-14T17:41:26.796342",
     "exception": false,
     "start_time": "2023-07-14T17:41:26.794633",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Segment-Anything"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f223aaf-3a0a-4f71-a6e1-4b3de62fe4f6",
   "metadata": {
    "papermill": {
     "duration": 0.008011,
     "end_time": "2023-07-14T17:41:26.854570",
     "exception": false,
     "start_time": "2023-07-14T17:41:26.846559",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "Generate masks for a given image using the Segment Anything library.\n",
    "https://github.com/facebookresearch/segment-anything\n",
    "Outputs a NumPy data file (`generate_masks.npy`) in the specified output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d1936b3c-0e07-45d1-85df-4c699aa13a08",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-14T17:41:26.863370Z",
     "iopub.status.busy": "2023-07-14T17:41:26.863370Z",
     "iopub.status.idle": "2023-07-14T17:41:51.819504Z",
     "shell.execute_reply": "2023-07-14T17:41:51.819504Z"
    },
    "papermill": {
     "duration": 24.964934,
     "end_time": "2023-07-14T17:41:51.819504",
     "exception": false,
     "start_time": "2023-07-14T17:41:26.854570",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting git+https://github.com/facebookresearch/segment-anything.git\n",
      "  Cloning https://github.com/facebookresearch/segment-anything.git to c:\\users\\colin\\appdata\\local\\temp\\pip-req-build-dzviiqsr\n",
      "  Resolved https://github.com/facebookresearch/segment-anything.git to commit 6fdee8f2727f4506cfbbe553e23b895e27956588\n",
      "  Preparing metadata (setup.py): started\n",
      "  Preparing metadata (setup.py): finished with status 'done'\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git 'C:\\Users\\colin\\AppData\\Local\\Temp\\pip-req-build-dzviiqsr'\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: opencv-python in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (4.8.0.74)\n",
      "Requirement already satisfied: numpy>=1.21.2 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from opencv-python) (1.25.0)\n",
      "Requirement already satisfied: torch in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (2.0.1)\n",
      "Requirement already satisfied: filelock in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch) (3.12.2)\n",
      "Requirement already satisfied: typing-extensions in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch) (4.6.3)\n",
      "Requirement already satisfied: sympy in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch) (1.12)\n",
      "Requirement already satisfied: networkx in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch) (3.1)\n",
      "Requirement already satisfied: jinja2 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch) (3.1.2)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from jinja2->torch) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from sympy->torch) (1.3.0)\n",
      "Requirement already satisfied: torchvision in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (0.15.2)\n",
      "Requirement already satisfied: numpy in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torchvision) (1.25.0)\n",
      "Requirement already satisfied: requests in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torchvision) (2.28.1)\n",
      "Requirement already satisfied: torch==2.0.1 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torchvision) (2.0.1)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torchvision) (10.0.0)\n",
      "Requirement already satisfied: filelock in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch==2.0.1->torchvision) (3.12.2)\n",
      "Requirement already satisfied: typing-extensions in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch==2.0.1->torchvision) (4.6.3)\n",
      "Requirement already satisfied: sympy in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch==2.0.1->torchvision) (1.12)\n",
      "Requirement already satisfied: networkx in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch==2.0.1->torchvision) (3.1)\n",
      "Requirement already satisfied: jinja2 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from torch==2.0.1->torchvision) (3.1.2)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from requests->torchvision) (2.1.1)\n",
      "Requirement already satisfied: idna<4,>=2.5 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from requests->torchvision) (3.4)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from requests->torchvision) (1.26.16)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from requests->torchvision) (2023.5.7)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from jinja2->torch==2.0.1->torchvision) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in d:\\python_projects\\anaconda\\envs\\claimed\\lib\\site-packages (from sympy->torch==2.0.1->torchvision) (1.3.0)\n"
     ]
    }
   ],
   "source": [
    "# Install Segment Anything straight from its GitHub repository, followed by the\n",
    "# opencv-python, torch and torchvision packages. No versions are pinned, so every\n",
    "# run installs whatever release is current at that time.\n",
    "!pip install git+https://github.com/facebookresearch/segment-anything.git\n",
    "!pip install opencv-python\n",
    "!pip install torch\n",
    "!pip install torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "300597ed-b9dc-4fdb-9cd4-da793c1aa5c5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-14T17:41:51.827704Z",
     "iopub.status.busy": "2023-07-14T17:41:51.827704Z",
     "iopub.status.idle": "2023-07-14T17:41:58.049120Z",
     "shell.execute_reply": "2023-07-14T17:41:58.049120Z"
    },
    "papermill": {
     "duration": 6.231825,
     "end_time": "2023-07-14T17:41:58.051329",
     "exception": false,
     "start_time": "2023-07-14T17:41:51.819504",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import re\n",
    "import sys\n",
    "\n",
    "import cv2\n",
    "import numpy as np\n",
    "import torch\n",
    "from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "29c527c9-ac90-4d98-a152-2b2301588aad",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-14T17:41:58.059501Z",
     "iopub.status.busy": "2023-07-14T17:41:58.059501Z",
     "iopub.status.idle": "2023-07-14T17:41:58.068369Z",
     "shell.execute_reply": "2023-07-14T17:41:58.068369Z"
    },
    "papermill": {
     "duration": 0.01704,
     "end_time": "2023-07-14T17:41:58.068369",
     "exception": false,
     "start_time": "2023-07-14T17:41:58.051329",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# model type, used as the lookup key into sam_model_registry\n",
    "model_type = os.environ.get(\"model_type\", \"vit_h\") # default, vit_h, vit_l or vit_b\n",
    "\n",
    "# checkpoint\n",
    "# path to the model checkpoint file; each model type requires its own checkpoint (no default value)\n",
    "checkpoint_path = os.environ.get(\"checkpoint_path\")\n",
    "\n",
    "# path to the input image to generate masks for (no default value)\n",
    "input_image_path = os.environ.get(\"input_image_path\")\n",
    "\n",
    "# temporary data storage for local execution; the generated masks are written here\n",
    "data_dir = os.environ.get('data_dir', '../../data/')\n",
    "\n",
    "# dummy_output (to be fixed once C3 supports < 1 outputs)\n",
    "# output_dummy = os.environ.get('output_dummy', 'output_dummy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "81032697-8712-4f91-99d5-66351a4341d3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-14T17:41:58.077312Z",
     "iopub.status.busy": "2023-07-14T17:41:58.077312Z",
     "iopub.status.idle": "2023-07-14T17:41:58.090562Z",
     "shell.execute_reply": "2023-07-14T17:41:58.090562Z"
    },
    "papermill": {
     "duration": 0.01325,
     "end_time": "2023-07-14T17:41:58.090562",
     "exception": false,
     "start_time": "2023-07-14T17:41:58.077312",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def parse_parameters(argv):\n",
    "    \"\"\"Extract `name=value` overrides from a list of command-line arguments.\n",
    "\n",
    "    Only arguments whose name is a valid ASCII identifier are accepted. Everything\n",
    "    after the first '=' is taken verbatim as the value, so Windows paths, quotes and\n",
    "    further '=' characters survive unchanged. Arguments that do not have this shape\n",
    "    (e.g. `-f`, `--opt=1`, `=x`) are ignored.\n",
    "\n",
    "    Args:\n",
    "        argv: iterable of argument strings, e.g. sys.argv.\n",
    "\n",
    "    Returns:\n",
    "        List of (name, value) string tuples in argument order.\n",
    "    \"\"\"\n",
    "    pattern = re.compile(r'([A-Za-z_][A-Za-z0-9_]*)=(.*)', re.DOTALL)\n",
    "    matches = (pattern.fullmatch(arg) for arg in argv)\n",
    "    return [match.groups() for match in matches if match is not None]\n",
    "\n",
    "\n",
    "# Command-line `name=value` arguments override the environment-based settings above.\n",
    "# Values are assigned as plain strings instead of exec()-ing generated source, so\n",
    "# backslashes are not interpreted as escape sequences and arguments cannot inject code.\n",
    "parameters = parse_parameters(sys.argv)\n",
    "\n",
    "for name, value in parameters:\n",
    "    logging.warning('Parameter: %s=\"%s\"', name, value)\n",
    "    globals()[name] = value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "df9f546d-a60e-47b7-b413-a302077dcb92",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-14T17:41:58.098756Z",
     "iopub.status.busy": "2023-07-14T17:41:58.098756Z",
     "iopub.status.idle": "2023-07-14T17:47:02.806490Z",
     "shell.execute_reply": "2023-07-14T17:47:02.806490Z"
    },
    "papermill": {
     "duration": 304.715928,
     "end_time": "2023-07-14T17:47:02.806490",
     "exception": false,
     "start_time": "2023-07-14T17:41:58.090562",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fail fast with a clear message when a required parameter is missing.\n",
    "if not checkpoint_path:\n",
    "    raise ValueError(\"checkpoint_path must point to the SAM checkpoint matching model_type\")\n",
    "if not input_image_path:\n",
    "    raise ValueError(\"input_image_path must point to the image to segment\")\n",
    "\n",
    "# Load the model and move it to the GPU when one is available, otherwise stay on CPU.\n",
    "# (The device used to be assigned but was never applied to the model.)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "sam = sam_model_registry[model_type](checkpoint=checkpoint_path)\n",
    "sam.to(device=device)\n",
    "\n",
    "# cv2.imread returns None instead of raising when the file cannot be read.\n",
    "image_bgr = cv2.imread(input_image_path)\n",
    "if image_bgr is None:\n",
    "    raise FileNotFoundError(f\"Could not read input image: {input_image_path}\")\n",
    "# OpenCV loads images in BGR channel order, whereas SAM expects RGB.\n",
    "input_image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "# generate masks for the whole image (automatic mode, no prompt involved)\n",
    "mask_generator = SamAutomaticMaskGenerator(sam)\n",
    "masks = mask_generator.generate(input_image)\n",
    "\n",
    "# Save the generated masks as <data_dir>/generate_masks.npy (np.save appends .npy).\n",
    "# os.path.join keeps the target correct whether or not data_dir ends with a separator.\n",
    "# NOTE(review): masks is presumably a list of per-mask dicts, which NumPy stores as a\n",
    "# pickled object array -- loading it back would then need np.load(..., allow_pickle=True).\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "np.save(os.path.join(data_dir, \"generate_masks\"), masks)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 340.340719,
   "end_time": "2023-07-14T17:47:03.918961",
   "environment_variables": {},
   "exception": null,
   "input_path": "D:\\Python_projects\\Claimed\\component-library\\component-library\\segment-anything\\generate-masks.ipynb",
   "output_path": "D:\\Python_projects\\Claimed\\component-library\\component-library\\segment-anything\\generate-masks.ipynb",
   "parameters": {},
   "start_time": "2023-07-14T17:41:23.578242",
   "version": "2.4.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
